{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 上一个代码的输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_pre = {'args': {'肺气肿': ['disease'], '百日咳': ['disease'], '血常规': ['check']}, 'question_types': ['check_disease']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'args': {'肺气肿': ['disease'], '百日咳': ['disease'], '血常规': ['check']},\n",
       " 'question_types': ['check_disease']}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_pre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = output_pre['args']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_dict = {}\n",
    "for arg, types in args.items():\n",
    "    for type in types:\n",
    "        if type not in entity_dict:\n",
    "            entity_dict[type] = [arg]\n",
    "        else:\n",
    "            entity_dict[type].append(arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'disease': ['肺气肿', '百日咳'], 'check': ['血常规']}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "question_types = output_pre['question_types']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['check_disease']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question_types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'check_disease'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sqls = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sql_transfer(question_type, entities):\n",
    "    '''\n",
    "    不同的提问意图，对应不同关系的cypher查询语句\n",
    "    '''\n",
    "    \n",
    "    if not entities:\n",
    "        return []\n",
    "\n",
    "    # 查询语句\n",
    "    sql = []\n",
    "    # 查询疾病的病因\n",
    "    if question_type == 'disease_cause':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.cause\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的预防措施\n",
    "    elif question_type == 'disease_prevent':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.prevent\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的疗程\n",
    "    elif question_type == 'disease_lasttime':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.cure_lasttime\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的治愈率\n",
    "    elif question_type == 'disease_cureprob':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.cured_prob\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的疗法\n",
    "    elif question_type == 'disease_cureway':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.cure_way\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的易感人群\n",
    "    elif question_type == 'disease_easyget':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.easy_get\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的描述\n",
    "    elif question_type == 'disease_desc':\n",
    "        sql = [\"MATCH (m:Disease) where m.name = '{0}' return m.name, m.desc\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的症状\n",
    "    elif question_type == 'disease_symptom':\n",
    "        sql = [\"MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    # 查询症状对应的疾病\n",
    "    elif question_type == 'symptom_disease':\n",
    "        sql = [\"MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病的并发症\n",
    "    elif question_type == 'disease_acompany':\n",
    "        sql1 = [\"MATCH (m:Disease)-[r:acompany_with]->(n:Disease) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql2 = [\"MATCH (m:Disease)-[r:acompany_with]->(n:Disease) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql = sql1 + sql2\n",
    "    # 查询疾病的忌口\n",
    "    elif question_type == 'disease_not_food':\n",
    "        sql = [\"MATCH (m:Disease)-[r:no_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    # 查询疾病建议吃的东西\n",
    "    elif question_type == 'disease_do_food':\n",
    "        sql1 = [\"MATCH (m:Disease)-[r:do_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql2 = [\"MATCH (m:Disease)-[r:recommand_eat]->(n:Food) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql = sql1 + sql2\n",
    "\n",
    "    # 已知忌口查疾病\n",
    "    elif question_type == 'food_not_disease':\n",
    "        sql = [\"MATCH (m:Disease)-[r:no_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    # 已知推荐查疾病\n",
    "    elif question_type == 'food_do_disease':\n",
    "        sql1 = [\"MATCH (m:Disease)-[r:do_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql2 = [\"MATCH (m:Disease)-[r:recommand_eat]->(n:Food) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql = sql1 + sql2\n",
    "\n",
    "    # 查询疾病常用药品－药品别名记得扩充\n",
    "    elif question_type == 'disease_drug':\n",
    "        sql1 = [\"MATCH (m:Disease)-[r:common_drug]->(n:Drug) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql2 = [\"MATCH (m:Disease)-[r:recommand_drug]->(n:Drug) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql = sql1 + sql2\n",
    "\n",
    "    # 已知药品查询能够治疗的疾病\n",
    "    elif question_type == 'drug_disease':\n",
    "        sql1 = [\"MATCH (m:Disease)-[r:common_drug]->(n:Drug) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql2 = [\"MATCH (m:Disease)-[r:recommand_drug]->(n:Drug) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "        sql = sql1 + sql2\n",
    "    # 查询疾病应该做的检查\n",
    "    elif question_type == 'disease_check':\n",
    "        sql = [\"MATCH (m:Disease)-[r:need_check]->(n:Check) where m.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    # 已知检查查询疾病\n",
    "    elif question_type == 'check_disease':\n",
    "        sql = [\"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '{0}' return m.name, r.name, n.name\".format(i) for i in entities]\n",
    "\n",
    "    return sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '肺气肿' return m.name, r.name, n.name\",\n",
       " \"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '百日咳' return m.name, r.name, n.name\"]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_transfer('check_disease', ['肺气肿', '百日咳'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "sqls = []\n",
    "for question_type in question_types:\n",
    "    sql_ = {}\n",
    "    sql_['question_type'] = question_type\n",
    "    sql_['sql'] = sql_transfer(question_type, entity_dict.get(question_type.split('_')[0]))\n",
    "    sqls.append(sql_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'question_type': 'check_disease',\n",
       "  'sql': [\"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '血常规' return m.name, r.name, n.name\"]}]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sqls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'check_disease'"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for question_type in question_types:\n",
    "#     sql_ = {}\n",
    "#     sql_['question_type'] = question_type\n",
    "#     sql = []\n",
    "#     if question_type == 'disease_symptom':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'symptom_disease':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('symptom'))\n",
    "\n",
    "#     elif question_type == 'disease_cause':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_acompany':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_not_food':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_do_food':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'food_not_disease':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('food'))\n",
    "\n",
    "#     elif question_type == 'food_do_disease':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('food'))\n",
    "\n",
    "#     elif question_type == 'disease_drug':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'drug_disease':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('drug'))\n",
    "\n",
    "#     elif question_type == 'disease_check':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'check_disease':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('check'))\n",
    "\n",
    "#     elif question_type == 'disease_prevent':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_lasttime':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_cureway':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_cureprob':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_easyget':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     elif question_type == 'disease_desc':\n",
    "#         sql = sql_transfer(question_type, entity_dict.get('disease'))\n",
    "\n",
    "#     if sql:\n",
    "#         sql_['sql'] = sql\n",
    "\n",
    "#         sqls.append(sql_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'question_type': 'check_disease',\n",
       "  'sql': [\"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '血常规' return m.name, r.name, n.name\"]}]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sqls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'question_type': 'check_disease',\n",
       "  'sql': [\"MATCH (m:Disease)-[r:need_check]->(n:Check) where n.name = '血常规' return m.name, r.name, n.name\"]}]"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sqls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}